import numpy as np
import torch
torch.manual_seed(2020)
from torch import nn
import torch.nn.functional as F

import pdb

def generate_total_sample(num_user, num_item):
    """Enumerate every (user, item) index pair.

    Rows are ordered with the user index varying slowest, i.e.
    ``[0, 0], [0, 1], ..., [0, num_item - 1], [1, 0], ...``.

    Returns:
        ``np.ndarray`` built from the list of ``[user, item]`` pairs.
    """
    pairs = [
        [user, item]
        for user in range(num_user)
        for item in range(num_item)
    ]
    return np.array(pairs)

def sigmoid(x):
    """Return the logistic function ``1 / (1 + exp(-x))`` evaluated with NumPy."""
    denominator = 1 + np.exp(-x)
    return 1.0 / denominator

def mse_func(x, y):
    """Return the mean of the squared element-wise difference between x and y."""
    squared_error = (x - y) ** 2
    return np.mean(squared_error)


class MF_BaseModel(nn.Module):
    """Matrix-factorisation model: sigmoid of the user/item embedding dot product."""

    def __init__(self, num_users, num_items, embedding_k=4, *args, **kwargs):
        """Create the user (``W``) and item (``H``) embedding tables.

        Args:
            num_users: number of rows in the user embedding table.
            num_items: number of rows in the item embedding table.
            embedding_k: embedding dimension.
            *args, **kwargs: accepted and ignored.
        """
        super(MF_BaseModel, self).__init__()
        self.num_users = num_users
        self.num_items = num_items
        self.embedding_k = embedding_k
        self.W = torch.nn.Embedding(self.num_users, self.embedding_k)
        self.H = torch.nn.Embedding(self.num_items, self.embedding_k)
        self.sigmoid = torch.nn.Sigmoid()
        self.xent_func = torch.nn.BCELoss()

    def forward(self, x, y, is_training=False):
        """Score aligned (user, item) index pairs.

        Args:
            x: user indices (anything ``torch.LongTensor`` accepts).
            y: item indices, aligned element-wise with ``x``.
            is_training: when true, also return the looked-up embeddings.

        Returns:
            ``out`` (sigmoid of the per-pair dot product), or the tuple
            ``(out, U_emb, V_emb)`` when ``is_training`` is true.
        """
        # Put the indices on whatever device the embedding tables live on,
        # instead of assuming CUDA is available.
        device = self.W.weight.device
        user_idx = torch.LongTensor(x).to(device)
        item_idx = torch.LongTensor(y).to(device)
        U_emb = self.W(user_idx)
        V_emb = self.H(item_idx)

        out = self.sigmoid(torch.sum(U_emb.mul(V_emb), 1))

        if is_training:
            # The previous return statement referenced names that were never
            # defined in this method, so this branch always raised NameError.
            return out, U_emb, V_emb
        return out

    def predict(self, x, y):
        """Return the scores for ``(x, y)`` as a detached CPU tensor."""
        pred = self.forward(x, y)
        return pred.detach().cpu()

class Embedding_Sharing(nn.Module):
    """Pair of user/item embedding tables that hands out raw embeddings."""

    def __init__(self, num_users, num_items, embedding_k=4, *args, **kwargs):
        """Create the user (``W``) and item (``H``) embedding tables.

        Args:
            num_users: number of rows in the user embedding table.
            num_items: number of rows in the item embedding table.
            embedding_k: embedding dimension.
            *args, **kwargs: accepted and ignored.
        """
        super(Embedding_Sharing, self).__init__()
        self.num_users = num_users
        self.num_items = num_items
        self.embedding_k = embedding_k
        self.W = torch.nn.Embedding(self.num_users, self.embedding_k)
        self.H = torch.nn.Embedding(self.num_items, self.embedding_k)
        self.relu = torch.nn.ReLU()
        self.sigmoid = torch.nn.Sigmoid()

        self.xent_func = torch.nn.BCELoss()

    def forward(self, x, y, is_training=True):
        """Look up the embeddings for user indices ``x`` and item indices ``y``.

        Args:
            x: user indices (anything ``torch.LongTensor`` accepts).
            y: item indices.
            is_training: kept for interface compatibility. The embeddings are
                returned in both modes; previously ``is_training=False`` fell
                off the end of the method and silently returned ``None``.

        Returns:
            ``(U_emb, V_emb)`` after ``torch.squeeze``. Note that squeeze
            drops every size-1 dimension, so a single index yields a 1-D
            vector of length ``embedding_k`` rather than a ``(1, k)`` matrix.
        """
        # Follow the device of the embedding tables rather than forcing CUDA.
        device = self.W.weight.device
        user_idx = torch.LongTensor(x).to(device)
        item_idx = torch.LongTensor(y).to(device)
        U_emb = self.W(user_idx)
        V_emb = self.H(item_idx)

        return torch.squeeze(U_emb), torch.squeeze(V_emb)
    
class Oneside_MLP(nn.Module):
    """Embedding lookup followed by a two-layer MLP with a sigmoid output."""

    def __init__(self, num, embedding_k = 4):
        """Build the embedding table and the two linear layers.

        Args:
            num: number of rows in the embedding table.
            embedding_k: embedding dimension and hidden width.
        """
        super().__init__()
        self.relu = torch.nn.ReLU()
        self.sigmoid = torch.nn.Sigmoid()
        self.xent_func = torch.nn.BCELoss()
        self.num = num
        self.embedding_k = embedding_k
        self.W = torch.nn.Embedding(self.num, self.embedding_k)
        self.linear_1 = torch.nn.Linear(embedding_k, embedding_k, bias = False)
        self.linear_2 = torch.nn.Linear(embedding_k, 1, bias = True)

    def forward(self, x):
        """Map indices ``x`` to sigmoid scores.

        Args:
            x: indices into the embedding table (anything
                ``torch.LongTensor`` accepts).

        Returns:
            The scores after ``torch.squeeze`` (one value per index; a single
            index collapses to a 0-d tensor because squeeze drops all size-1
            dimensions).
        """
        # Follow the device of the embedding table rather than forcing CUDA.
        idx = torch.LongTensor(x).to(self.W.weight.device)
        x = self.W(idx)
        x = self.linear_1(x)
        x = self.relu(x)
        x = self.linear_2(x)
        x = self.sigmoid(x)

        return torch.squeeze(x)
    
class MF_UDR(nn.Module):
    """Propensity model trained with a user-wise weighted balancing constraint.

    Propensities are the sigmoid of the dot product of the embeddings held by
    ``self.propensity_model``. ``linear_1``/``linear_2`` turn the user
    embeddings into softmax weights over users that weight the per-user
    constraint term in :meth:`fit`.
    """

    def __init__(self, num_users, num_items, embedding_k=16, l2_reg_lambda = 1e-4, *args, **kwargs):
        """Create the embedding tables, the weighting layers and the propensity model.

        Args:
            num_users: number of users.
            num_items: number of items.
            embedding_k: embedding dimension.
            l2_reg_lambda: stored on the instance; not read by ``fit`` or
                ``predict`` in this class.
            *args, **kwargs: accepted and ignored.
        """
        super().__init__()
        self.num_users = num_users
        self.num_items = num_items
        self.embedding_k = embedding_k
        # W / H are not read by fit() or predict(); both go through
        # self.propensity_model. They are kept so the module's parameters
        # (and state_dict keys) stay unchanged.
        self.W = torch.nn.Embedding(self.num_users, self.embedding_k)
        self.H = torch.nn.Embedding(self.num_items, self.embedding_k)

        self.l2_reg_lambda = l2_reg_lambda

        self.linear_1 = torch.nn.Linear(embedding_k, embedding_k, bias = True)
        self.linear_2 = torch.nn.Linear(embedding_k, 1, bias = False)

        self.propensity_model = Embedding_Sharing(num_users, num_items, embedding_k)

        self.sigmoid = torch.nn.Sigmoid()
        self.xent_func = torch.nn.BCELoss()

    def fit(self, x, y, rating, r_hat, obs, L = 5, gamma = 0.02,
        num_epoch=1000, batch_size=20, lr=0.05, lamb=1e-4, 
        tol=1e-4, verbose=False):
        """Train the propensity model on mini-batches of users.

        Each batch takes ``batch_size`` rows (users) of the inputs; the
        flattened rows must contain ``batch_size * num_items`` entries. Users
        left over after ``num_users // batch_size`` full batches are skipped
        for that epoch.

        Args:
            x: 2-D array of user indices, indexed as ``x[user_rows, :]``.
            y: 2-D array of item indices with the same layout as ``x``.
            rating: 2-D array of regression targets for the propensity
                (converted with ``torch.Tensor``).
            r_hat: 2-D tensor indexed as ``r_hat[user_rows, :]``; used inside
                logarithms, so presumably probabilities in (0, 1) - confirm
                against the caller.
            obs: 2-D tensor with the same layout as ``r_hat``; presumably the
                0/1 observation indicator - confirm against the caller.
            L: weight of the constraint loss.
            gamma: lower clipping bound applied to the propensity before
                inverting it.
            num_epoch: maximum number of epochs.
            batch_size: number of users per batch.
            lr: Adam learning rate.
            lamb: Adam weight decay.
            tol: relative epoch-loss improvement below which an epoch counts
                towards early stopping.
            verbose: print the epoch loss every 10 epochs.

        Training stops on the fifth epoch (not necessarily consecutive) whose
        relative loss improvement is below ``tol``.
        """
        optimizer = torch.optim.Adam(self.parameters(), lr=lr, weight_decay=lamb)
        last_loss = 1e9

        # Keep every tensor on the device the propensity embeddings live on.
        device = self.propensity_model.W.weight.device
        # Guards both log terms against log(0).
        eps = 1e-6

        num_sample = self.num_users
        total_batch = num_sample // batch_size
        # User ids in natural order, used for the attention weights so that
        # alpha[u] is always the weight of user u.
        ordered_users = np.arange(num_sample)

        early_stop = 0
        for epoch in range(num_epoch):
            all_idx = np.arange(num_sample)
            np.random.shuffle(all_idx)
            epoch_loss = 0

            for idx in range(total_batch):
                # mini-batch training
                selected_idx = all_idx[batch_size*idx:(idx+1)*batch_size]
                sub_u = x[selected_idx, :].reshape(-1)
                sub_i = y[selected_idx, :].reshape(-1)
                sub_y = rating[selected_idx, :].reshape(-1)

                sub_y = torch.Tensor(sub_y).to(device)

                u_emb, v_emb = self.propensity_model.forward(sub_u, sub_i)

                pred = self.sigmoid(torch.sum(u_emb.mul(v_emb), 1))

                assert pred.shape == sub_y.shape
                # Squared-error fit of the propensity (the variable name is
                # historical; this is MSE, not cross-entropy).
                xent_loss = nn.MSELoss()(pred, sub_y)

                prop_inv = (1/torch.clip(pred, gamma, 1)).reshape([len(selected_idx), self.num_items])
                sub_obs = obs[selected_idx, :].to(device)
                sub_r_hat = r_hat[selected_idx, :].to(device)

                # Softmax weights over ALL users, computed in id order. The
                # shuffled all_idx must not be used here: alpha is indexed by
                # user id below, so a shuffled lookup would pair each user
                # with another user's weight.
                all_u_emb, _ = self.propensity_model.forward(ordered_users, [0])

                u_emb_active = self.linear_1(all_u_emb)
                u_emb_active = torch.tanh(u_emb_active).squeeze()
                alpha = nn.Softmax(dim = 0)(self.linear_2(u_emb_active).squeeze())

                sub_alpha = alpha[selected_idx]

                # Same guarded term on both sides of the constraint.
                log_odds = -torch.log(sub_r_hat + eps) + torch.log(1 - sub_r_hat + eps)
                constrain1 = torch.sum(sub_obs * prop_inv * log_odds, dim = 1)
                constrain2 = torch.sum(log_odds, dim = 1)

                contrain_loss = torch.sum(sub_alpha * ((constrain1 - constrain2) ** 2))/(len(selected_idx) * self.num_items)

                loss = xent_loss + L*contrain_loss

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                epoch_loss += loss.detach().cpu().numpy()

            relative_loss_div = (last_loss-epoch_loss)/(last_loss+1e-10)
            if  relative_loss_div < tol:
                if early_stop > 3:
                    print("[MF] epoch:{}".format(epoch))
                    break
                early_stop += 1

            last_loss = epoch_loss

            if epoch % 10 == 0 and verbose:
                print("[MF] epoch:{}, xent:{}".format(epoch, epoch_loss))

            if epoch == num_epoch - 1:
                print("[MF] Reach preset epochs, it seems does not converge.")

    def predict(self, x, y):
        """Return the inverse propensity ``1 / sigmoid(u . v)`` as a NumPy array.

        The value is not clipped here (``gamma`` is only applied in ``fit``).
        """
        u_emb, v_emb = self.propensity_model.forward(x, y)
        pred = 1/self.sigmoid(torch.sum(u_emb.mul(v_emb), 1))
        return pred.detach().cpu().numpy()
        

        
class MF_IDR(nn.Module):
    """Propensity model trained with an item-wise weighted balancing constraint.

    Propensities are the sigmoid of the dot product of the embeddings held by
    ``self.propensity_model``. ``linear_1``/``linear_2`` turn the item
    embeddings into softmax weights over items that weight the per-item
    constraint term in :meth:`fit`.
    """

    def __init__(self, num_users, num_items, embedding_k=16, l2_reg_lambda = 1e-4, *args, **kwargs):
        """Create the embedding tables, the weighting layers and the propensity model.

        Args:
            num_users: number of users.
            num_items: number of items.
            embedding_k: embedding dimension.
            l2_reg_lambda: stored on the instance; not read by ``fit`` or
                ``predict`` in this class.
            *args, **kwargs: accepted and ignored.
        """
        super().__init__()
        self.num_users = num_users
        self.num_items = num_items
        self.embedding_k = embedding_k
        # W / H are not read by fit() or predict(); both go through
        # self.propensity_model. They are kept so the module's parameters
        # (and state_dict keys) stay unchanged.
        self.W = torch.nn.Embedding(self.num_users, self.embedding_k)
        self.H = torch.nn.Embedding(self.num_items, self.embedding_k)
        self.l2_reg_lambda = l2_reg_lambda

        self.linear_1 = torch.nn.Linear(embedding_k, embedding_k, bias = True)
        self.linear_2 = torch.nn.Linear(embedding_k, 1, bias = False)

        self.propensity_model = Embedding_Sharing(num_users, num_items, embedding_k)

        self.sigmoid = torch.nn.Sigmoid()
        self.xent_func = torch.nn.BCELoss()

    def fit(self, x, y, rating, r_hat, obs, L = 5, gamma = 0.02,
        num_epoch=1000, batch_size=20, lr=0.05, lamb=1e-4, 
        tol=1e-4, verbose=False):
        """Train the propensity model on mini-batches of items.

        Each batch takes ``batch_size`` columns (items) of the inputs; the
        flattened columns must contain ``num_users * batch_size`` entries.
        Items left over after ``num_items // batch_size`` full batches are
        skipped for that epoch.

        Args:
            x: 2-D array of user indices, indexed as ``x[:, item_cols]``.
            y: 2-D array of item indices with the same layout as ``x``.
            rating: 2-D array of regression targets for the propensity
                (converted with ``torch.Tensor``).
            r_hat: 2-D tensor indexed as ``r_hat[:, item_cols]``; used inside
                logarithms, so presumably probabilities in (0, 1) - confirm
                against the caller.
            obs: 2-D tensor with the same layout as ``r_hat``; presumably the
                0/1 observation indicator - confirm against the caller.
            L: weight of the constraint loss.
            gamma: lower clipping bound applied to the propensity before
                inverting it.
            num_epoch: maximum number of epochs.
            batch_size: number of items per batch.
            lr: Adam learning rate.
            lamb: Adam weight decay.
            tol: relative epoch-loss improvement below which an epoch counts
                towards early stopping.
            verbose: print the epoch loss every 10 epochs.

        Training stops on the fifth epoch (not necessarily consecutive) whose
        relative loss improvement is below ``tol``.
        """
        optimizer = torch.optim.Adam(self.parameters(), lr=lr, weight_decay=lamb)
        last_loss = 1e9

        # Keep every tensor on the device the propensity embeddings live on.
        device = self.propensity_model.W.weight.device
        # Guards both log terms against log(0) when r_hat saturates at 0 or 1.
        eps = 1e-6

        num_sample = self.num_items
        total_batch = num_sample // batch_size
        # Item ids in natural order, used for the attention weights so that
        # alpha[i] is always the weight of item i.
        ordered_items = np.arange(num_sample)

        early_stop = 0
        for epoch in range(num_epoch):
            all_idx = np.arange(num_sample)
            np.random.shuffle(all_idx)
            epoch_loss = 0

            for idx in range(total_batch):
                # mini-batch training
                selected_idx = all_idx[batch_size*idx:(idx+1)*batch_size]
                sub_u = x[:, selected_idx].reshape(-1)
                sub_i = y[:, selected_idx].reshape(-1)
                sub_y = rating[:, selected_idx].reshape(-1)
                sub_y = torch.Tensor(sub_y).to(device)

                u_emb, v_emb = self.propensity_model.forward(sub_u, sub_i)

                pred = self.sigmoid(torch.sum(u_emb.mul(v_emb), 1))

                assert pred.shape == sub_y.shape
                # Squared-error fit of the propensity (the variable name is
                # historical; this is MSE, not cross-entropy).
                xent_loss = nn.MSELoss()(pred, sub_y)

                prop_inv = (1/torch.clip(pred, gamma, 1)).reshape([self.num_users, len(selected_idx)])
                sub_obs = obs[:, selected_idx].to(device)
                sub_r_hat = r_hat[:, selected_idx].to(device)

                # Softmax weights over ALL items, computed in id order. The
                # shuffled all_idx must not be used here: alpha is indexed by
                # item id below, so a shuffled lookup would pair each item
                # with another item's weight.
                _, all_i_emb = self.propensity_model.forward([0], ordered_items)

                i_emb_active = self.linear_1(all_i_emb)
                i_emb_active = torch.tanh(i_emb_active).squeeze()

                alpha = nn.Softmax(dim = 0)(self.linear_2(i_emb_active).squeeze())

                sub_alpha = alpha[selected_idx]

                # Same guarded term on both sides of the constraint.
                log_odds = -torch.log(sub_r_hat + eps) + torch.log(1 - sub_r_hat + eps)
                constrain1 = torch.sum(sub_obs * prop_inv * log_odds, dim = 0)
                constrain2 = torch.sum(log_odds, dim = 0)

                contrain_loss = torch.sum(sub_alpha * ((constrain1 - constrain2) ** 2))/(len(selected_idx) * self.num_users)

                loss = xent_loss + L*contrain_loss

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                epoch_loss += loss.detach().cpu().numpy()

            relative_loss_div = (last_loss-epoch_loss)/(last_loss+1e-10)
            if  relative_loss_div < tol:
                if early_stop > 3:
                    print("[MF] epoch:{}".format(epoch))
                    break
                early_stop += 1

            last_loss = epoch_loss

            if epoch % 10 == 0 and verbose:
                print("[MF] epoch:{}, xent:{}".format(epoch, epoch_loss))

            if epoch == num_epoch - 1:
                print("[MF] Reach preset epochs, it seems does not converge.")

    def predict(self, x, y):
        """Return the inverse propensity ``1 / sigmoid(u . v)`` as a NumPy array.

        The value is not clipped here (``gamma`` is only applied in ``fit``).
        """
        u_emb, v_emb = self.propensity_model.forward(x, y)
        pred = 1/self.sigmoid(torch.sum(u_emb.mul(v_emb), 1))
        return pred.detach().cpu().numpy()
    
    
class MF_UIDR(nn.Module):
    """Propensity model trained with both user-wise and item-wise constraints.

    Propensities are the sigmoid of the dot product of the embeddings held by
    ``self.propensity_model``. Two small heads (``linear_*_u`` and
    ``linear_*_i``) turn the user and item embeddings into softmax weights
    that weight the per-user and per-item constraint terms in :meth:`fit`.
    """

    def __init__(self, num_users, num_items, embedding_k=4, l2_reg_lambda = 1e-4, *args, **kwargs):
        """Create the embedding tables, both weighting heads and the propensity model.

        Args:
            num_users: number of users.
            num_items: number of items.
            embedding_k: embedding dimension.
            l2_reg_lambda: stored on the instance; not read by ``fit`` or
                ``predict`` in this class.
            *args, **kwargs: accepted and ignored.
        """
        super().__init__()
        self.num_users = num_users
        self.num_items = num_items
        self.embedding_k = embedding_k
        # W / H are not read by fit() or predict(); both go through
        # self.propensity_model. They are kept so the module's parameters
        # (and state_dict keys) stay unchanged.
        self.W = torch.nn.Embedding(self.num_users, self.embedding_k)
        self.H = torch.nn.Embedding(self.num_items, self.embedding_k)
        self.l2_reg_lambda = l2_reg_lambda

        self.linear_1_u = torch.nn.Linear(embedding_k, embedding_k, bias = True)
        self.linear_2_u = torch.nn.Linear(embedding_k, 1, bias = False)

        self.linear_1_i = torch.nn.Linear(embedding_k, embedding_k, bias = True)
        self.linear_2_i = torch.nn.Linear(embedding_k, 1, bias = False)

        self.propensity_model = Embedding_Sharing(num_users, num_items, embedding_k)

        self.sigmoid = torch.nn.Sigmoid()
        self.xent_func = torch.nn.BCELoss()

    def fit(self, x, y, rating, r_hat, obs, L = 10, gamma = 0.02,
        num_epoch=1000, batch_size_u=20, lr=0.05, lamb=1e-4, 
        tol=1e-4, verbose=False):
        """Train the propensity model on paired user and item mini-batches.

        Every step draws ``batch_size_u`` users (rows) and
        ``batch_size_u * num_items // num_users`` items (columns). The number
        of steps per epoch is ``num_users // batch_size_u``.

        Args:
            x: 2-D array of user indices, indexed by row (users) and column
                (items).
            y: 2-D array of item indices with the same layout as ``x``.
            rating: 2-D array of regression targets for the propensity
                (converted with ``torch.Tensor``).
            r_hat: 2-D tensor with the same layout; used inside logarithms, so
                presumably probabilities in (0, 1) - confirm against the
                caller.
            obs: 2-D tensor with the same layout as ``r_hat``; presumably the
                0/1 observation indicator - confirm against the caller.
            L: weight of the combined constraint loss.
            gamma: lower clipping bound applied to the propensity before
                inverting it.
            num_epoch: maximum number of epochs.
            batch_size_u: number of users per step.
            lr: Adam learning rate.
            lamb: Adam weight decay.
            tol: relative epoch-loss improvement below which an epoch counts
                towards early stopping.
            verbose: print the epoch loss every 10 epochs.

        Training stops on the seventh epoch (not necessarily consecutive)
        whose relative loss improvement is below ``tol``.
        """
        optimizer = torch.optim.Adam(self.parameters(), lr=lr, weight_decay=lamb)
        last_loss = 1e9

        # Keep every tensor on the device the propensity embeddings live on.
        device = self.propensity_model.W.weight.device
        # Guards the log terms against log(0) when r_hat saturates at 0 or 1.
        eps = 1e-6

        batch_size_i = batch_size_u * self.num_items // self.num_users
        total_batch = self.num_users // batch_size_u

        # Ids in natural order, used for the attention weights so that
        # alpha_u[u] / alpha_i[i] always belong to user u / item i.
        ordered_users = np.arange(self.num_users)
        ordered_items = np.arange(self.num_items)

        early_stop = 0
        for epoch in range(num_epoch):

            all_idx_u = np.arange(self.num_users)
            all_idx_i = np.arange(self.num_items)

            np.random.shuffle(all_idx_u)
            np.random.shuffle(all_idx_i)
            epoch_loss = 0

            for idx in range(total_batch):
                # mini-batch training
                selected_idx_u = all_idx_u[batch_size_u*idx:(idx+1)*batch_size_u]
                selected_idx_i = all_idx_i[batch_size_i*idx:(idx+1)*batch_size_i]

                sub_u_u = x[selected_idx_u, :].reshape(-1)
                sub_u_i = y[selected_idx_u, :].reshape(-1)

                sub_i_u = x[:, selected_idx_i].reshape(-1)
                sub_i_i = y[:, selected_idx_i].reshape(-1)

                sub_u_y = rating[selected_idx_u, :].reshape(-1)
                sub_i_y = rating[:, selected_idx_i].reshape(-1)
                sub_u_y = torch.Tensor(sub_u_y).to(device)
                sub_i_y = torch.Tensor(sub_i_y).to(device)

                # Squared-error fit of the propensity on the user rows and on
                # the item columns (variable names are historical; this is
                # MSE, not cross-entropy).
                u_emb_u, v_emb_u = self.propensity_model.forward(sub_u_u, sub_u_i)
                pred_u = self.sigmoid(torch.sum(u_emb_u.mul(v_emb_u), 1))
                xent_loss_u = nn.MSELoss(reduction = 'sum')(pred_u, sub_u_y)

                u_emb_i, v_emb_i = self.propensity_model.forward(sub_i_u, sub_i_i)
                pred_i = self.sigmoid(torch.sum(u_emb_i.mul(v_emb_i), 1))
                xent_loss_i = nn.MSELoss(reduction = 'sum')(pred_i, sub_i_y)
                xent_loss = (xent_loss_u + xent_loss_i)/(batch_size_u * self.num_items)

                prop_inv_u = (1/torch.clip(pred_u, gamma, 1)).reshape([len(selected_idx_u), self.num_items])
                prop_inv_i = (1/torch.clip(pred_i, gamma, 1)).reshape([self.num_users, len(selected_idx_i)])

                sub_obs_u = obs[selected_idx_u, :].to(device)
                sub_r_hat_u = r_hat[selected_idx_u, :].to(device)
                sub_obs_i = obs[:, selected_idx_i].to(device)
                sub_r_hat_i = r_hat[:, selected_idx_i].to(device)

                # --- user-wise constraint ---------------------------------
                # Softmax weights over ALL users in id order; the shuffled
                # all_idx_u must not be used because alpha_u is indexed by
                # user id below.
                all_u_emb, _ = self.propensity_model.forward(ordered_users, [0])

                u_emb_active = self.linear_1_u(all_u_emb)
                u_emb_active = torch.tanh(u_emb_active).squeeze()

                alpha_u = nn.Softmax(dim = 0)(self.linear_2_u(u_emb_active).squeeze())

                sub_alpha_u = alpha_u[selected_idx_u]

                log_odds_u = -torch.log(sub_r_hat_u + eps) + torch.log(1 - sub_r_hat_u + eps)
                constrain1_u = torch.sum(sub_obs_u * prop_inv_u * log_odds_u, dim = 1)
                constrain2_u = torch.sum(log_odds_u, dim = 1)

                contrain_loss_u = torch.sum(sub_alpha_u * ((constrain1_u - constrain2_u) ** 2))/(len(selected_idx_u) * self.num_items)

                # --- item-wise constraint ---------------------------------
                # Same reasoning as above: item embeddings in id order.
                _, all_i_emb = self.propensity_model.forward([0], ordered_items)

                i_emb_active = self.linear_1_i(all_i_emb)
                i_emb_active = torch.tanh(i_emb_active).squeeze()

                alpha_i = nn.Softmax(dim = 0)(self.linear_2_i(i_emb_active).squeeze())

                sub_alpha_i = alpha_i[selected_idx_i]

                log_odds_i = -torch.log(sub_r_hat_i + eps) + torch.log(1 - sub_r_hat_i + eps)
                constrain1_i = torch.sum(sub_obs_i * prop_inv_i * log_odds_i, dim = 0)
                constrain2_i = torch.sum(log_odds_i, dim = 0)

                contrain_loss_i = torch.sum(sub_alpha_i * ((constrain1_i - constrain2_i) ** 2))/(len(selected_idx_i) * self.num_users)
                contrain_loss = contrain_loss_u + contrain_loss_i

                loss = xent_loss + L * contrain_loss
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                epoch_loss += loss.detach().cpu().numpy()

            relative_loss_div = (last_loss-epoch_loss)/(last_loss+1e-10)
            if  relative_loss_div < tol:
                if early_stop > 5:
                    print("[MF] epoch:{}".format(epoch))
                    break
                early_stop += 1

            last_loss = epoch_loss

            if epoch % 10 == 0 and verbose:
                print("[MF] epoch:{}, xent:{}".format(epoch, epoch_loss))

            if epoch == num_epoch - 1:
                print("[MF] Reach preset epochs, it seems does not converge.")

    def predict(self, x, y):
        """Return the inverse propensity ``1 / sigmoid(u . v)`` as a NumPy array.

        The value is not clipped here (``gamma`` is only applied in ``fit``).
        """
        u_emb, v_emb = self.propensity_model.forward(x, y)
        pred = 1/self.sigmoid(torch.sum(u_emb.mul(v_emb), 1))
        return pred.detach().cpu().numpy()